NLLLossGrad

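Computes the backward pass of the NLLLoss (Negative Log Likelihood Loss) operator, producing the gradient with respect to the input logits.

Depending on the reduction_type used by the forward NLLLoss, the operator distributes the upstream gradient back in the matching way. Only the entry at each sample's true class index receives a non-zero gradient; every other entry is 0.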

\[\begin{split}\frac{\partial L}{\partial x_{i,j}} = \begin{cases} - g_i \cdot w_{y_i}, & j = y_i,\ \text{reduction = none} \\ - g \cdot w_{y_i}, & j = y_i,\ \text{reduction = sum} \\ - g \cdot \dfrac{w_{y_i}}{\sum w}, & j = y_i,\ \text{reduction = mean} \\ 0, & j \neq y_i \end{cases}\end{split}\]

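Where:

  • \(g_i\) is the per-sample loss gradient (loss_grad[i])

  • \(g\) is the scalar gradient of the reduced loss (loss_grad[0])

  • \(y_i\) is the true class index of the \(i\)-th sample

  • \(w_{y_i}\) is the weight of that class

  • \(\sum w\) is the weight sum passed in through total_weight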
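
The piecewise formula above can be sketched as a plain host-side C reference, which may help when checking device results. This is a minimal sketch, not the library's implementation: the function name `nlllossgrad_ref` and the `NLL_REDUCE_*` constants are hypothetical, the numeric encoding (0 = Sum, 1 = Mean, 2 = None) follows the reduction_type list and the call examples in this page, and leaving the row at zero for an out-of-range label is an assumption.

```c
#include <stddef.h>

/* Hypothetical names for the reduction_type values listed in this page. */
enum { NLL_REDUCE_SUM = 0, NLL_REDUCE_MEAN = 1, NLL_REDUCE_NONE = 2 };

/* Hypothetical host-side reference for the fp variant (not a library API). */
void nlllossgrad_ref(const float *loss_grad, const int *labels,
                     const float *weight, const float *total_weight,
                     float *logits_grad, int batch, int class_num,
                     int reduction_type)
{
    for (int i = 0; i < batch; ++i) {
        float *row = logits_grad + (size_t)i * (size_t)class_num;

        /* Every position starts at 0; only j == y_i is overwritten below. */
        for (int j = 0; j < class_num; ++j)
            row[j] = 0.0f;

        int y = labels[i];
        if (y < 0 || y >= class_num)
            continue;  /* assumption: an invalid label leaves the row at 0 */

        /* None uses the per-sample gradient g_i; Sum and Mean use scalar g. */
        float g = (reduction_type == NLL_REDUCE_NONE) ? loss_grad[i]
                                                      : loss_grad[0];
        float w = weight[y];
        if (reduction_type == NLL_REDUCE_MEAN)
            w /= total_weight[0];  /* w_{y_i} / sum(w) */

        row[y] = -g * w;
    }
}
```

For example, with batch = 2, class_num = 3, labels = {1, 2}, weight = {1, 2, 2}, loss_grad = {0.5, 0.25} and total_weight = 4, the None mode yields -1.0 at position [0][1] and -0.5 at [1][2], while the Mean mode yields -0.25 at both positions.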

Inputs:
  • logits - address of the forward input logits, shape [batch, class_num] (used only for size information).

  • loss_grad - address of the upstream loss gradient:
    • when reduction_type = 2 (None), the shape is [batch]

    • when reduction_type = 0 / 1 (Sum / Mean), only loss_grad[0] is used

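  • labels - address of the true label indices, shape [batch]

  • weight - address of the class weight array, shape [class_num]

  • total_weight - address of the weight sum (used only when reduction_type is Mean).

  • batch - batch size.

  • class_num - number of classes.

  • reduction_type - loss reduction mode: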
    • 0:Sum

    • 1:Mean

    • 2:None

  • core_mask - core mask (shared-storage version only).
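
The page describes total_weight only as the weight sum used by the Mean reduction. A minimal sketch of how a caller might prepare it is shown below, assuming the common NLLLoss convention in which the sum runs over the weight of each sample's true class, i.e. the sum of weight[labels[i]] over the batch. Both that convention and the helper name `nll_total_weight` are assumptions, not something this page confirms; in practice the value would normally come from the forward NLLLoss.

```c
/* Hypothetical helper (not a library API): sums weight[labels[i]] over the
 * batch. Assumes every label already satisfies 0 <= label < class_num. */
float nll_total_weight(const int *labels, const float *weight, int batch)
{
    float sum = 0.0f;
    for (int i = 0; i < batch; ++i)
        sum += weight[labels[i]];
    return sum;
}
```

With labels = {1, 2, 1} and weight = {1, 2, 4}, this gives 2 + 4 + 2 = 8.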

Outputs:
  • logits_grad - address of the output gradient with respect to logits, shape [batch, class_num]
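
The shapes listed above determine how many elements each buffer must hold. The sketch below computes the two counts that vary with the arguments; the helper names are hypothetical, and the value 2 for None follows the reduction_type list in this page.

```c
#include <stddef.h>

/* Hypothetical helper: number of loss_grad elements the operator reads.
 * None (reduction_type == 2) reads [batch]; Sum and Mean read loss_grad[0]. */
size_t nll_loss_grad_elems(int batch, int reduction_type)
{
    return (reduction_type == 2) ? (size_t)batch : (size_t)1;
}

/* Hypothetical helper: number of logits_grad elements written,
 * i.e. batch * class_num. */
size_t nll_logits_grad_elems(int batch, int class_num)
{
    return (size_t)batch * (size_t)class_num;
}
```

For batch = 32 and class_num = 1000, logits_grad holds 32000 elements, which is 128000 bytes for the fp type.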

Supported platforms:

FT78NE MT7004

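Notes

  • FT78NE supports only the fp type

  • MT7004 supports the hp and fp types

  • The output gradient is always 0 at positions other than the true class

  • Every index in labels must satisfy 0 <= label < class_num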
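
Since the labels constraint above is stated as a requirement on the caller, it can be worth checking before the call. A minimal sketch of such a check follows; the function name `nll_first_invalid_label` is hypothetical and not part of the library.

```c
/* Hypothetical helper (not a library API): returns the position of the first
 * label outside [0, class_num), or -1 if every label is valid. */
int nll_first_invalid_label(const int *labels, int batch, int class_num)
{
    for (int i = 0; i < batch; ++i) {
        if (labels[i] < 0 || labels[i] >= class_num)
            return i;
    }
    return -1;
}
```

A caller could skip the operator call, or report an error, whenever the result is not -1.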

Shared-storage version:

void hp_nlllossgrad_s(half *logits, half *loss_grad, int *labels, half *weight, half *total_weight, half *logits_grad, int batch, int class_num, int reduction_type, int core_mask)
void fp_nlllossgrad_s(float *logits, float *loss_grad, int *labels, float *weight, float *total_weight, float *logits_grad, int batch, int class_num, int reduction_type, int core_mask)

C call example:

// FT78NE example
#include <stdio.h>
#include <nlllossgrad.h>

int main(int argc, char* argv[]) {
    float *logits       = (float *)0xA0000000;   // [batch, class_num]
    float *loss_grad    = (float *)0xA0001000;   // upstream gradient; only loss_grad[0] is used for Mean
    int   *labels       = (int *)0xA0002000;     // [batch]
    float *weight       = (float *)0xA0003000;   // [class_num]
    float *total_weight = (float *)0xA0004000;   // weight sum, used by Mean
    float *logits_grad  = (float *)0xC0000000;   // [batch, class_num] output

    int batch = 32;
    int class_num = 1000;
    int reduction_type = 1;  // Mean
    int core_mask = 0xff;

    fp_nlllossgrad_s(logits, loss_grad, labels, weight, total_weight,
                     logits_grad, batch, class_num, reduction_type, core_mask);
    return 0;
}

Private-storage version:

void hp_nlllossgrad_p(half *logits, half *loss_grad, int *labels, half *weight, half *total_weight, half *logits_grad, int batch, int class_num, int reduction_type)
void fp_nlllossgrad_p(float *logits, float *loss_grad, int *labels, float *weight, float *total_weight, float *logits_grad, int batch, int class_num, int reduction_type)

C call example:

// FT78NE example
#include <stdio.h>
#include <nlllossgrad.h>

int main(int argc, char* argv[]) {
    float *logits       = (float *)0x10810000;   // L2 space
    float *loss_grad    = (float *)0x10820000;   // upstream gradient, [batch] for None
    int   *labels       = (int *)0x10830000;     // [batch]
    float *weight       = (float *)0x10840000;   // [class_num]
    float *total_weight = (float *)0x10850000;   // used only by Mean
    float *logits_grad  = (float *)0x10860000;   // [batch, class_num] output

    int batch = 32;
    int class_num = 1000;
    int reduction_type = 2;  // None

    fp_nlllossgrad_p(logits, loss_grad, labels, weight, total_weight,
                     logits_grad, batch, class_num, reduction_type);
    return 0;
}